"""
Phase 5: Forward Predictions & Out-of-Sample Validation
========================================================
1. Out-of-sample test: train on pre-2007 data, predict 2007-2012 crises
2. Forward vulnerability rankings: latest (2020–2024) demographics → predicted crisis score
3. Demographic trajectories: extrapolate observed 2000–05 → 2018–24 shifts to ~2040 → future vulnerability
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

# Directory layout: this file sits two levels below PROJECT_DIR
# (PROJECT_DIR/<subdir>/<this file>), and PROJECT_DIR's parent is ROOT_DIR.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
# The sibling "multilateral" project provides the PanelGLS estimator; its src/
# directory is put at the front of sys.path so the import below resolves.
MULTILATERAL_DIR = ROOT_DIR / "multilateral"
sys.path.insert(0, str(MULTILATERAL_DIR / "src"))
from model import PanelGLS

# Input directory (crises_panel.csv is read from here in main) and the
# directory that receives the Markdown tables written by each analysis.
DATA = PROJECT_DIR / "data" / "processed"
OUT_TABLES = PROJECT_DIR / "output" / "tables"
# Created at import time so every table writer can assume it exists.
OUT_TABLES.mkdir(parents=True, exist_ok=True)


def stars(p):
    """Return the conventional significance marker for p-value ``p``.

    '***' for p < 0.01, '**' for p < 0.05, '*' for p < 0.1, otherwise ''.
    """
    cutoffs = ((0.01, '***'), (0.05, '**'), (0.1, '*'))
    for level, marker in cutoffs:
        if p < level:
            return marker
    return ''


# ── Out-of-sample prediction ─────────────────────────────────────────

def _oos_classification_metrics(y_true, y_pred_binary):
    """Return (true_positives, precision, recall, f1) for 0/1 outcome/prediction arrays.

    Precision, recall and F1 fall back to 0 when their denominator is zero.
    """
    tp = int(((y_pred_binary == 1) & (y_true == 1)).sum())
    fp = int(((y_pred_binary == 1) & (y_true == 0)).sum())
    fn = int(((y_pred_binary == 0) & (y_true == 1)).sum())

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    return tp, precision, recall, f1


def _write_oos_table(oos_results):
    """Write the out-of-sample results as Markdown to OUT_TABLES/out_of_sample.md."""
    lines = ["# Out-of-Sample Prediction: Train ≤2006, Test 2007–2012\n"]
    lines.append("| Dep. Variable | Model | Train N | Test N | Train R² | "
                 "AUC | Recall | Precision | F1 | Actual | Predicted |")
    lines.append("|:---|:---|---:|---:|---:|---:|---:|---:|---:|---:|---:|")
    for r in oos_results:
        auc_str = f"{r['auc']:.3f}" if not np.isnan(r['auc']) else "—"
        lines.append(f"| {r['dep_var']} | {r['model']} | {r['train_n']} | "
                     f"{r['test_n']} | {r['train_r2']:.4f} | {auc_str} | "
                     f"{r['recall']:.2f} | {r['precision']:.2f} | {r['f1']:.2f} | "
                     f"{r['actual_events']} | {r['predicted_events']} |")

    lines.append("\n*Threshold = training-set event base rate. "
                 "AUC computed from predicted scores (Xβ without fixed effects).*")

    path = OUT_TABLES / "out_of_sample.md"
    # Explicit UTF-8: the table contains non-ASCII characters (≤, –, ², β, —)
    # that the platform default encoding may not be able to represent.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")


def out_of_sample_test(df):
    """Train on pre-2007 data, predict 2007-2012 crises.

    For each outcome and each regressor set ('EW only', 'EW + Z'), fits
    PanelGLS on years <= 2006, scores 2007-2012 observations with Xβ (no fixed
    effects), classifies at the training base rate, and reports precision,
    recall, F1, AUC and the point-biserial correlation. Specifications with
    fewer than 100 training or 30 test rows, or whose fit raises, are skipped.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel containing the outcome columns, regressors, 'iso3' and 'year'.

    Returns
    -------
    list[dict]
        One result dict per successfully estimated specification; also written
        to OUT_TABLES/out_of_sample.md when non-empty.
    """
    from scipy import stats as sp_stats

    print("\n" + "=" * 60)
    print("OUT-OF-SAMPLE TEST: Train ≤2006, Predict 2007-2012")
    print("=" * 60)

    dep_vars = ['banking_crisis_onset', 'any_crisis_onset', 'ca_reversal']
    x_vars_ew = ['ca_gdp_lag1', 'fiscal_bal_gdp', 'rgdp_growth', 'kaopen', 'nfa_gdp_lag']
    x_vars_ewz = ['Z_1', 'Z_2', 'Z_3'] + x_vars_ew

    oos_results = []

    for dep_var in dep_vars:
        print(f"\n--- {dep_var} ---")

        for x_vars, model_label in [(x_vars_ew, 'EW only'), (x_vars_ewz, 'EW + Z')]:
            # NOTE(review): missing values are dropped per specification, so
            # 'EW only' and 'EW + Z' may be evaluated on different samples.
            cols = [dep_var] + x_vars + ['iso3', 'year']
            sub = df[cols].dropna()

            train = sub[sub['year'] <= 2006]
            test = sub[(sub['year'] >= 2007) & (sub['year'] <= 2012)]

            if len(train) < 100 or len(test) < 30:
                print(f"  {model_label}: insufficient data (train={len(train)}, test={len(test)})")
                continue

            gls = PanelGLS()
            y_train = train[dep_var].values
            try:
                gls.fit(y_train, train[x_vars].values,
                        train['iso3'].values, train['year'].values)
            except Exception as e:
                print(f"  {model_label}: training failed ({e})")
                continue

            # Panel GLS prediction would be ŷ = X·β + country_FE + year_FE, but
            # test-period year effects are not estimated, so score with X·β only.
            X_test = test[x_vars].values
            y_test = test[dep_var].values
            y_pred = X_test @ gls.beta

            # NOTE(review): if PanelGLS absorbs the mean into the fixed effects,
            # Xβ is not on the probability scale and comparing it to the base
            # rate may put nearly everything in one class; consider a quantile
            # threshold on training-set Xβ instead.
            threshold = y_train.mean()
            y_pred_binary = (y_pred > threshold).astype(int)

            actual_events = y_test.sum()
            predicted_events = y_pred_binary.sum()
            true_positives, precision, recall, f1 = _oos_classification_metrics(
                y_test, y_pred_binary)

            # AUC and the point-biserial correlation are only defined when the
            # test window contains both crisis and non-crisis observations.
            if 0 < actual_events < len(y_test):
                auc = compute_auc(y_test, y_pred)
                corr, corr_p = sp_stats.pointbiserialr(y_test, y_pred)
            else:
                auc = np.nan
                corr, corr_p = np.nan, np.nan

            oos_results.append({
                'dep_var': dep_var,
                'model': model_label,
                'train_n': len(train),
                'test_n': len(test),
                'train_r2': gls.r_squared,
                'actual_events': int(actual_events),
                'predicted_events': int(predicted_events),
                'true_positives': true_positives,
                'precision': precision,
                'recall': recall,
                'f1': f1,
                'auc': auc,
                'corr': corr,
                'corr_p': corr_p,
            })

            print(f"  {model_label}: Train R²={gls.r_squared:.4f}, "
                  f"AUC={auc:.3f}, Recall={recall:.2f}, Precision={precision:.2f}")
            print(f"    Actual crises: {int(actual_events)}, "
                  f"Predicted: {int(predicted_events)}, TP: {true_positives}")

    if oos_results:
        _write_oos_table(oos_results)

    return oos_results


def compute_auc(y_true, y_scores):
    """Compute ROC AUC from scratch (no sklearn dependency).

    Uses the rank form of the Wilcoxon-Mann-Whitney statistic. Tied scores
    receive average ranks, so a positive/negative pair with equal scores counts
    as 1/2 (the standard AUC convention) instead of 0 or 1 depending on sort order.

    Parameters
    ----------
    y_true : array-like
        Outcomes; entries equal to 1 are positives, all others negatives.
    y_scores : array-like
        Predicted scores; larger means "more likely positive".

    Returns
    -------
    float
        AUC in [0, 1], or NaN when only one class is present.
    """
    y_true = np.asarray(y_true).ravel()
    y_scores = np.asarray(y_scores, dtype=float).ravel()

    is_pos = y_true == 1
    n_pos = int(is_pos.sum())
    n_neg = len(y_true) - n_pos

    if n_pos == 0 or n_neg == 0:
        return np.nan

    # 1-based average ranks: a tie group of size c ending at cumulative
    # position k occupies ranks k-c+1 .. k, whose mean is k - (c-1)/2.
    _, inverse, counts = np.unique(y_scores, return_inverse=True, return_counts=True)
    avg_rank = np.cumsum(counts) - (counts - 1) / 2.0
    ranks = avg_rank[np.ravel(inverse)]

    # Mann-Whitney U for the positive class, normalised by the number of pairs.
    u_pos = ranks[is_pos].sum() - n_pos * (n_pos + 1) / 2.0
    return float(u_pos / (n_pos * n_neg))


# ── Forward vulnerability rankings ──────────────────────────────────

def _latest_complete_obs(df, x_vars, start_year=2020):
    """Return one row per country: its most recent year >= start_year with every x_var observed.

    Whole rows are kept so that all regressors for a country come from the same
    year (``groupby().last()`` would take the last non-null value of each column
    independently and could splice regressors from different years).
    """
    recent = df.loc[df['year'] >= start_year, ['iso3', 'year'] + x_vars].dropna()
    recent = recent.sort_values(['iso3', 'year'], kind='mergesort')
    return recent.drop_duplicates(subset='iso3', keep='last').reset_index(drop=True)


def _country_effects(sub, dep_var, x_vars, beta):
    """Return the mean in-sample residual y - Xβ per country (Series indexed by iso3)."""
    resid = sub[dep_var].values - sub[x_vars].values @ beta
    return pd.Series(resid, index=sub.index).groupby(sub['iso3']).mean()


def _write_vulnerability_table(rankings):
    """Write the banking-crisis ranking (plus CA-reversal score if present) as Markdown."""
    rk = rankings['banking_crisis_onset'].rename(columns={
        'score_banking_crisis_onset': 'score',
        'pred_banking_crisis_onset': 'xbeta',
        'fe_banking_crisis_onset': 'fe',
    })
    rk['rank'] = range(1, len(rk) + 1)

    # Also attach the CA reversal score for comparison
    if 'ca_reversal' in rankings:
        rk_ca = rankings['ca_reversal'][['iso3', 'score_ca_reversal']].rename(
            columns={'score_ca_reversal': 'ca_score'})
        rk = rk.merge(rk_ca, on='iso3', how='left')

    lines = ["# Forward Vulnerability Rankings (2024 Demographics)\n"]
    lines.append("Countries ranked by predicted banking crisis probability (Xβ + country FE).\n")

    # Top 25 most vulnerable
    lines.append("## Most Vulnerable Countries\n")
    if 'ca_score' in rk.columns:
        lines.append("| Rank | Country | Crisis Score | Xβ Component | Country FE | CA Reversal Score |")
        lines.append("|---:|:---|---:|---:|---:|---:|")
        for _, row in rk.head(25).iterrows():
            ca_str = f"{row['ca_score']:.4f}" if pd.notna(row.get('ca_score')) else "—"
            lines.append(f"| {row['rank']} | {row['iso3']} | {row['score']:.4f} | "
                         f"{row['xbeta']:.4f} | {row['fe']:.4f} | {ca_str} |")
    else:
        lines.append("| Rank | Country | Crisis Score | Xβ Component | Country FE |")
        lines.append("|---:|:---|---:|---:|---:|")
        for _, row in rk.head(25).iterrows():
            lines.append(f"| {row['rank']} | {row['iso3']} | {row['score']:.4f} | "
                         f"{row['xbeta']:.4f} | {row['fe']:.4f} |")

    # Bottom 10 least vulnerable
    lines.append("\n## Least Vulnerable Countries\n")
    lines.append("| Rank | Country | Crisis Score | Xβ Component | Country FE |")
    lines.append("|---:|:---|---:|---:|---:|")
    for _, row in rk.tail(10).iterrows():
        lines.append(f"| {row['rank']} | {row['iso3']} | {row['score']:.4f} | "
                     f"{row['xbeta']:.4f} | {row['fe']:.4f} |")

    lines.append(f"\n*Rankings based on {len(rk)} countries with complete 2020–2024 data. "
                 "Score = Xβ (demographic + macro fundamentals) + country fixed effect.*")

    path = OUT_TABLES / "vulnerability_rankings.md"
    # Explicit UTF-8: the table contains non-ASCII characters (β, –, —).
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")


def forward_vulnerability(df):
    """Rank countries by predicted crisis vulnerability using latest data.

    For each outcome, fits PanelGLS on the full sample and scores each
    country's latest complete observation (2020 onward) as Xβ plus that
    country's mean in-sample residual (its estimated fixed effect; 0 for
    countries absent from the estimation sample). Prints the top 20 and bottom
    10 countries and, when the banking-crisis model fits, writes
    OUT_TABLES/vulnerability_rankings.md.

    Returns
    -------
    dict[str, pandas.DataFrame]
        Per outcome: columns iso3, year, score_<dep>, pred_<dep>, fe_<dep>,
        sorted by score descending. Outcomes with < 100 complete rows or a
        failed fit are omitted.
    """
    print("\n" + "=" * 60)
    print("FORWARD VULNERABILITY RANKINGS (2024 Demographics)")
    print("=" * 60)

    # Train on full sample
    x_vars = ['Z_1', 'Z_2', 'Z_3', 'fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'kaopen']
    dep_vars = ['banking_crisis_onset', 'any_crisis_onset', 'ca_reversal']

    # The scoring frame does not depend on the outcome, so build it once.
    latest_all = _latest_complete_obs(df, x_vars)

    rankings = {}

    for dep_var in dep_vars:
        cols = [dep_var] + x_vars + ['iso3', 'year']
        sub = df[cols].dropna()
        if len(sub) < 100:
            continue

        gls = PanelGLS()
        try:
            gls.fit(sub[dep_var].values, sub[x_vars].values,
                    sub['iso3'].values, sub['year'].values)
        except Exception as e:
            # Carry on with the other outcomes, but report why this one is missing.
            print(f"  {dep_var}: training failed ({e})")
            continue

        if len(latest_all) == 0:
            continue

        score_col = f'score_{dep_var}'
        pred_col = f'pred_{dep_var}'
        fe_col = f'fe_{dep_var}'

        latest = latest_all.copy()
        # Predicted vulnerability (Xβ only, no FE)
        latest[pred_col] = latest[x_vars].values @ gls.beta
        country_fe = _country_effects(sub, dep_var, x_vars, gls.beta)
        latest[fe_col] = latest['iso3'].map(country_fe).fillna(0)
        latest[score_col] = latest[pred_col] + latest[fe_col]

        ranked = latest[['iso3', 'year', score_col, pred_col, fe_col]].sort_values(
            score_col, ascending=False)
        rankings[dep_var] = ranked

        print(f"\n  {dep_var} — Top 20 Most Vulnerable:")
        for _, row in ranked.head(20).iterrows():
            print(f"    {row['iso3']}: score={row[score_col]:.4f} "
                  f"(Xβ={row[pred_col]:.4f}, FE={row[fe_col]:.4f})")

        print(f"\n  {dep_var} — Bottom 10 (Least Vulnerable):")
        for _, row in ranked.tail(10).iterrows():
            print(f"    {row['iso3']}: score={row[score_col]:.4f}")

    if 'banking_crisis_onset' in rankings:
        _write_vulnerability_table(rankings)

    return rankings


# ── Demographic trajectory projections ──────────────────────────────

def forward_projections(df):
    """Use demographic trends to project future vulnerability shifts.

    Compares country means over 2000–2005 with 2018–2024, prints the fastest
    agers and the fastest youth-dependency decliners, and writes
    OUT_TABLES/demographic_trajectory.md. The "2040" projection adds the
    observed change once more to the late-period level: the period midpoints
    are about 18.5 years apart, so one further step lands near 2040.

    Returns None. Returns early if fewer than 20 countries appear in both periods.
    """
    print("\n" + "=" * 60)
    print("DEMOGRAPHIC TRAJECTORY: WHO IS BECOMING MORE VULNERABLE?")
    print("=" * 60)

    dem_cols = ['Z_1', 'Z_2', 'Z_3', 'old_dep', 'youth_dep']
    # Compare early 2000s demographics to 2020s — who shifted most?
    early = df[(df['year'] >= 2000) & (df['year'] <= 2005)].groupby('iso3')[dem_cols].mean()
    late = df[(df['year'] >= 2018) & (df['year'] <= 2024)].groupby('iso3')[dem_cols].mean()

    common = early.index.intersection(late.index)
    if len(common) < 20:
        print("  Insufficient overlapping countries for trajectory analysis")
        return

    delta = late.loc[common] - early.loc[common]
    delta.columns = [f'd_{c}' for c in delta.columns]

    # Attach levels and projections while delta is still indexed by iso3, so
    # every derived column is label-aligned with its own country no matter how
    # the frame is sorted afterwards.
    delta['current_old_dep'] = late.loc[common, 'old_dep']
    delta['current_youth_dep'] = late.loc[common, 'youth_dep']
    # Linear extrapolation: repeat the observed shift once more (≈ 2040).
    delta['proj_old_dep_2040'] = delta['current_old_dep'] + delta['d_old_dep']
    delta['proj_youth_dep_2040'] = delta['current_youth_dep'] + delta['d_youth_dep']
    # Risk trajectory: countries moving FROM young (high reversal risk)
    # TO old (different risk profile — bank profitability)
    delta['risk_transition'] = delta['d_old_dep'] - delta['d_youth_dep']
    delta = delta.reset_index()

    # Countries with biggest increase in old_dep (aging fastest)
    by_aging = delta.sort_values('d_old_dep', ascending=False)
    print("\n  Fastest aging (Δold_dep, 2000s → 2020s):")
    for _, row in by_aging.head(15).iterrows():
        print(f"    {row['iso3']}: Δold_dep={row['d_old_dep']:+.4f}, "
              f"ΔZ₁={row['d_Z_1']:+.4f}")

    # Countries with biggest decrease in youth_dep (youth bulge declining)
    by_youth_decline = delta.sort_values('d_youth_dep', ascending=True)
    print("\n  Fastest youth decline (Δyouth_dep):")
    for _, row in by_youth_decline.head(15).iterrows():
        print(f"    {row['iso3']}: Δyouth_dep={row['d_youth_dep']:+.4f}")

    # The "fastest aging" table is ordered by risk_transition (Δold - Δyouth).
    delta = delta.sort_values('risk_transition', ascending=False)

    lines = ["# Demographic Trajectory and Future Vulnerability\n"]
    lines.append("Countries ranked by demographic shift (2000–05 → 2018–24).\n")

    lines.append("## Fastest Aging Countries (Risk Profile Shifting)\n")
    lines.append("| Country | Current old_dep | Δold_dep | Δyouth_dep | Projected old_dep 2040 |")
    lines.append("|:---|---:|---:|---:|---:|")
    for _, row in delta.head(20).iterrows():
        lines.append(f"| {row['iso3']} | {row['current_old_dep']:.3f} | "
                     f"{row['d_old_dep']:+.4f} | {row['d_youth_dep']:+.4f} | "
                     f"{row['proj_old_dep_2040']:.3f} |")

    lines.append("\n## Still-Young Countries (Elevated Reversal Risk)\n")
    still_young = delta[delta['current_youth_dep'] > delta['current_youth_dep'].median()]
    still_young = still_young.sort_values('current_youth_dep', ascending=False)
    lines.append("| Country | Current youth_dep | Δyouth_dep | Current old_dep |")
    lines.append("|:---|---:|---:|---:|")
    for _, row in still_young.head(15).iterrows():
        lines.append(f"| {row['iso3']} | {row['current_youth_dep']:.3f} | "
                     f"{row['d_youth_dep']:+.4f} | {row['current_old_dep']:.3f} |")

    lines.append(f"\n*Projections assume linear extrapolation of 2000–2024 demographic trends. "
                 f"Based on {len(common)} countries with data in both periods.*")

    path = OUT_TABLES / "demographic_trajectory.md"
    # Explicit UTF-8: the table contains non-ASCII characters (Δ, –, →).
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")


# ── Main ─────────────────────────────────────────────────────────────

def main():
    """Load the processed crises panel and run the three Phase 5 analyses in order."""
    banner = "=" * 70
    print(banner)
    print("PHASE 5: FORWARD PREDICTIONS & OUT-OF-SAMPLE VALIDATION")
    print(banner)

    panel = pd.read_csv(DATA / "crises_panel.csv")
    n_countries = panel['iso3'].nunique()
    print(f"Panel: {len(panel)} obs, {n_countries} countries")

    # 1. out-of-sample test, 2. forward vulnerability rankings,
    # 3. demographic trajectory projections; each may write a table to OUT_TABLES.
    for analysis in (out_of_sample_test, forward_vulnerability, forward_projections):
        analysis(panel)

    print("\n" + banner)
    print("PHASE 5 COMPLETE")
    print(banner)


if __name__ == '__main__':
    main()
